{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import argparse\n",
    "import numpy as np\n",
    "from detectron2.config import get_cfg\n",
    "from detectron2 import model_zoo\n",
    "from detectron2.engine import DefaultTrainer\n",
    "import pickle\n",
    "from fvcore.common.file_io import PathManager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Register my amodal datasets \n",
    "from detectron2.data.datasets import register_coco_instances\n",
    "from detectron2.data import MetadataCatalog\n",
    "register_coco_instances(\"amodal_train\", {}, \"datasets/coco/annotations/COCO_amodal_train2014_with_classes_poly.json\", \"datasets/coco/train2014\")\n",
    "# Prepare test datasets \n",
    "from detectron2.data.datasets import register_coco_instances\n",
    "from detectron2.data import MetadataCatalog\n",
    "register_coco_instances(\"amodal_val\", {}, \"datasets/coco/annotations/COCO_amodal_val2014_with_classes_poly.json\", \"datasets/coco/val2014\")\n",
    "from detectron2.data import DatasetCatalog\n",
    "dataset_dicts = DatasetCatalog.get(\"amodal_train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m[03/02 00:43:07 d2.engine.defaults]: \u001b[0mModel:\n",
      "GeneralizedRCNN(\n",
      "  (backbone): FPN(\n",
      "    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (top_block): LastLevelMaxPool()\n",
      "    (bottom_up): ResNet(\n",
      "      (stem): BasicStem(\n",
      "        (conv1): Conv2d(\n",
      "          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n",
      "          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "        )\n",
      "      )\n",
      "      (res2): Sequential(\n",
      "        (0): BottleneckBlock(\n",
      "          (shortcut): Conv2d(\n",
      "            64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv1): Conv2d(\n",
      "            64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (1): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (2): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (res3): Sequential(\n",
      "        (0): BottleneckBlock(\n",
      "          (shortcut): Conv2d(\n",
      "            256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv1): Conv2d(\n",
      "            256, 128, kernel_size=(1, 1), stride=(2, 2), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (1): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (2): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (3): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (res4): Sequential(\n",
      "        (0): BottleneckBlock(\n",
      "          (shortcut): Conv2d(\n",
      "            512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "          (conv1): Conv2d(\n",
      "            512, 256, kernel_size=(1, 1), stride=(2, 2), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (1): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (2): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (3): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (4): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (5): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (res5): Sequential(\n",
      "        (0): BottleneckBlock(\n",
      "          (shortcut): Conv2d(\n",
      "            1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05)\n",
      "          )\n",
      "          (conv1): Conv2d(\n",
      "            1024, 512, kernel_size=(1, 1), stride=(2, 2), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (1): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "        (2): BottleneckBlock(\n",
      "          (conv1): Conv2d(\n",
      "            2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv2): Conv2d(\n",
      "            512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)\n",
      "          )\n",
      "          (conv3): Conv2d(\n",
      "            512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False\n",
      "            (norm): FrozenBatchNorm2d(num_features=2048, eps=1e-05)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (proposal_generator): RPN(\n",
      "    (anchor_generator): DefaultAnchorGenerator(\n",
      "      (cell_anchors): BufferList()\n",
      "    )\n",
      "    (rpn_head): StandardRPNHead(\n",
      "      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (objectness_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (anchor_deltas): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "  )\n",
      "  (roi_heads): StandardROIHeads(\n",
      "    (box_pooler): ROIPooler(\n",
      "      (level_poolers): ModuleList(\n",
      "        (0): ROIAlign(output_size=(7, 7), spatial_scale=0.25, sampling_ratio=0, aligned=True)\n",
      "        (1): ROIAlign(output_size=(7, 7), spatial_scale=0.125, sampling_ratio=0, aligned=True)\n",
      "        (2): ROIAlign(output_size=(7, 7), spatial_scale=0.0625, sampling_ratio=0, aligned=True)\n",
      "        (3): ROIAlign(output_size=(7, 7), spatial_scale=0.03125, sampling_ratio=0, aligned=True)\n",
      "      )\n",
      "    )\n",
      "    (box_head): FastRCNNConvFCHead(\n",
      "      (fc1): Linear(in_features=12544, out_features=1024, bias=True)\n",
      "      (fc2): Linear(in_features=1024, out_features=1024, bias=True)\n",
      "    )\n",
      "    (box_predictor): FastRCNNOutputLayers(\n",
      "      (cls_score): Linear(in_features=1024, out_features=81, bias=True)\n",
      "      (bbox_pred): Linear(in_features=1024, out_features=320, bias=True)\n",
      "    )\n",
      "    (mask_pooler): ROIPooler(\n",
      "      (level_poolers): ModuleList(\n",
      "        (0): ROIAlign(output_size=(14, 14), spatial_scale=0.25, sampling_ratio=0, aligned=True)\n",
      "        (1): ROIAlign(output_size=(14, 14), spatial_scale=0.125, sampling_ratio=0, aligned=True)\n",
      "        (2): ROIAlign(output_size=(14, 14), spatial_scale=0.0625, sampling_ratio=0, aligned=True)\n",
      "        (3): ROIAlign(output_size=(14, 14), spatial_scale=0.03125, sampling_ratio=0, aligned=True)\n",
      "      )\n",
      "    )\n",
      "    (mask_head): MaskRCNNConvUpsampleHead(\n",
      "      (mask_fcn1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (mask_fcn2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (mask_fcn3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (mask_fcn4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (deconv): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))\n",
      "      (predictor): Conv2d(256, 80, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[5m\u001b[31mWARNING\u001b[0m \u001b[32m[03/02 00:43:08 d2.data.datasets.coco]: \u001b[0m\n",
      "Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.\n",
      "\n",
      "\u001b[32m[03/02 00:43:08 d2.data.datasets.coco]: \u001b[0mLoaded 2276 images in COCO format from datasets/coco/annotations/COCO_amodal_train2014_with_classes_poly.json\n",
      "\u001b[32m[03/02 00:43:08 d2.data.build]: \u001b[0mRemoved 0 images with no usable annotations. 2276 images left.\n",
      "\u001b[32m[03/02 00:43:08 d2.data.build]: \u001b[0mDistribution of instances among all 80 categories:\n",
      "\u001b[36m|   category    | #instances   |   category   | #instances   |   category    | #instances   |\n",
      "|:-------------:|:-------------|:------------:|:-------------|:-------------:|:-------------|\n",
      "|    person     | 2311         |   bicycle    | 21           |      car      | 279          |\n",
      "|  motorcycle   | 47           |   airplane   | 65           |      bus      | 92           |\n",
      "|     train     | 60           |    truck     | 115          |     boat      | 79           |\n",
      "| traffic light | 13           | fire hydrant | 35           |   stop sign   | 12           |\n",
      "| parking meter | 10           |    bench     | 48           |     bird      | 57           |\n",
      "|      cat      | 94           |     dog      | 100          |     horse     | 57           |\n",
      "|     sheep     | 102          |     cow      | 91           |   elephant    | 80           |\n",
      "|     bear      | 28           |    zebra     | 84           |    giraffe    | 90           |\n",
      "|   backpack    | 14           |   umbrella   | 74           |    handbag    | 26           |\n",
      "|      tie      | 9            |   suitcase   | 77           |    frisbee    | 29           |\n",
      "|     skis      | 13           |  snowboard   | 21           |  sports ball  | 36           |\n",
      "|     kite      | 49           | baseball bat | 33           | baseball gl.. | 5            |\n",
      "|  skateboard   | 41           |  surfboard   | 37           | tennis racket | 49           |\n",
      "|    bottle     | 222          |  wine glass  | 70           |      cup      | 215          |\n",
      "|     fork      | 25           |    knife     | 69           |     spoon     | 33           |\n",
      "|     bowl      | 161          |    banana    | 26           |     apple     | 34           |\n",
      "|   sandwich    | 33           |    orange    | 31           |   broccoli    | 9            |\n",
      "|    carrot     | 19           |   hot dog    | 27           |     pizza     | 60           |\n",
      "|     donut     | 70           |     cake     | 76           |     chair     | 195          |\n",
      "|     couch     | 55           | potted plant | 12           |      bed      | 26           |\n",
      "| dining table  | 34           |    toilet    | 80           |      tv       | 91           |\n",
      "|    laptop     | 74           |    mouse     | 25           |    remote     | 28           |\n",
      "|   keyboard    | 36           |  cell phone  | 38           |   microwave   | 33           |\n",
      "|     oven      | 23           |   toaster    | 5            |     sink      | 53           |\n",
      "| refrigerator  | 63           |     book     | 59           |     clock     | 39           |\n",
      "|     vase      | 45           |   scissors   | 8            |  teddy bear   | 68           |\n",
      "|  hair drier   | 0            |  toothbrush  | 10           |               |              |\n",
      "|     total     | 6763         |              |              |               |              |\u001b[0m\n",
      "\u001b[32m[03/02 00:43:08 d2.data.detection_utils]: \u001b[0mTransformGens used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]\n",
      "\u001b[32m[03/02 00:43:08 d2.data.build]: \u001b[0mUsing training sampler TrainingSampler\n"
     ]
    }
   ],
   "source": [
    "cfg = get_cfg()\n",
    "cfg.merge_from_file(model_zoo.get_config_file(\"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml\"))\n",
    "cfg.DATALOADER.NUM_WORKERS = 2\n",
    "cfg.DATASETS.TRAIN = (\"amodal_train\",)\n",
    "cfg.DATASETS.TEST = (\"amodal_val\",)\n",
    "cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml\")  # Let training initialize from model zoo\n",
    "cfg.SOLVER.IMS_PER_BATCH = 2\n",
    "cfg.SOLVER.BASE_LR = 0.0005  # pick a good LR\n",
    "cfg.SOLVER.MAX_ITER = 1000    # 300 iterations seems good enough for this toy dataset; you may need to train longer for a practical dataset\n",
    "os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
    "cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set the testing threshold for this model\n",
    "trainer = DefaultTrainer(cfg) \n",
    "# import pdb; pdb.set_trace()\n",
    "trainer.resume_or_load(resume=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/waiyu/.torch/fvcore_cache/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl\n"
     ]
    }
   ],
   "source": [
    "pretrained_path = \"~/.torch/fvcore_cache/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl\"\n",
    "DETECTRON_PATH = os.path.expanduser(pretrained_path)\n",
    "print(DETECTRON_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "with PathManager.open(DETECTRON_PATH, \"rb\") as f:\n",
    "    data = pickle.load(f, encoding=\"latin1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = dict(data['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "d['roi_heads.visible_mask_head.visible_mask_fcn1.weight'] = d['roi_heads.mask_head.mask_fcn1.weight']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn1.bias'] = d['roi_heads.mask_head.mask_fcn1.bias']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn2.weight'] = d['roi_heads.mask_head.mask_fcn2.weight']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn2.bias'] = d['roi_heads.mask_head.mask_fcn2.bias']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn3.weight'] = d['roi_heads.mask_head.mask_fcn3.weight']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn3.bias'] = d['roi_heads.mask_head.mask_fcn3.bias']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn4.weight'] = d['roi_heads.mask_head.mask_fcn4.weight']\n",
    "d['roi_heads.visible_mask_head.visible_mask_fcn4.bias'] = d['roi_heads.mask_head.mask_fcn4.bias']\n",
    "d['roi_heads.visible_mask_head.deconv.weight'] = d['roi_heads.mask_head.deconv.weight']\n",
    "d['roi_heads.visible_mask_head.deconv.bias'] = d['roi_heads.mask_head.deconv.bias']\n",
    "d['roi_heads.visible_mask_head.predictor.weight'] = d['roi_heads.mask_head.predictor.weight']\n",
    "d['roi_heads.visible_mask_head.predictor.bias'] = d['roi_heads.mask_head.predictor.bias']\n",
    "# Amodal \n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn1.weight'] = d['roi_heads.mask_head.mask_fcn1.weight']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn1.bias'] = d['roi_heads.mask_head.mask_fcn1.bias']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn2.weight'] = d['roi_heads.mask_head.mask_fcn2.weight']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn2.bias'] = d['roi_heads.mask_head.mask_fcn2.bias']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn3.weight'] = d['roi_heads.mask_head.mask_fcn3.weight']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn3.bias'] = d['roi_heads.mask_head.mask_fcn3.bias']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn4.weight'] = d['roi_heads.mask_head.mask_fcn4.weight']\n",
    "d['roi_heads.amodal_mask_head.amodal_mask_fcn4.bias'] = d['roi_heads.mask_head.mask_fcn4.bias']\n",
    "d['roi_heads.amodal_mask_head.deconv.weight'] = d['roi_heads.mask_head.deconv.weight']\n",
    "d['roi_heads.amodal_mask_head.deconv.bias'] = d['roi_heads.mask_head.deconv.bias']\n",
    "d['roi_heads.amodal_mask_head.predictor.weight'] = d['roi_heads.mask_head.predictor.weight']\n",
    "d['roi_heads.amodal_mask_head.predictor.bias'] = d['roi_heads.mask_head.predictor.bias']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "newdict = data\n",
    "newdict['model'] = d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(256, 256, 3, 3)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "newdict['model']['roi_heads.amodal_mask_head.amodal_mask_fcn1.weight'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saved to ./pretrained_model/amodal_mask_orcnn_R_50_FPN_3x.pth.\n"
     ]
    }
   ],
   "source": [
    "!mkdir pretrained_model\n",
    "torch.save(newdict, \"./pretrained_model/amodal_mask_orcnn_R_50_FPN_3x.pth\")\n",
    "print('saved to {}.'.format(\"./pretrained_model/amodal_mask_orcnn_R_50_FPN_3x.pth\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
